from transformers import AutoModelForCausalLM
# Load the causal language model from `model_path`, requesting bfloat16
# weights and automatic device mapping.
# NOTE(review): `model_path` and `torch` are not defined in the visible
# part of this file -- confirm both are in scope before this runs.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# This method must be called when gradient checkpointing is enabled.
model.enable_input_require_grads()

# Print the loaded model for inspection.
print(model)
